#!/usr/bin/env python3
# coding: utf-8
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common helpers: pure-intrinsic call builders, reducers (mmad, all_op) and cast utilities."""
import akg.tvm
from .elewise_compute import vmuls, vadds, vmax, vmin, vabs, vrec, vmul, set_is_need_save_dtype
from .cast_compute import floor, round, cast


def fargmax(x, y):
    """
    Build a pure-intrinsic ``fargmax`` call over two input expressions.

    The intrinsic yields the index of the larger of ``x`` and ``y``; it is meant
    to serve as the combiner of an argmax ``comm_reducer`` (see example). The
    call expression carries the dtype of ``x``.

    Args:
        x (tvm.expr.Expr): First input expression.
        y (tvm.expr.Expr): Second input expression.

    Returns:
        tvm.expr.Expr. The call expression.

    Examples:
        >>> n = akg.tvm.var('n')
        >>> m = akg.tvm.var('m')
        >>> data = akg.tvm.placeholder((n, m), name='data')
        >>> k = akg.tvm.reduce_axis((0, m), "k")
        >>> reducer = akg.tvm.comm_reducer(lambda x,y: akg.fargmax(x, y), lambda t: akg.tvm.min_value(t), name="argmax")
        >>> res = akg.tvm.compute((n,), lambda *indice: reducer(data(*indice, k), axis=k), name="res")
    """
    result_dtype = x.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "fargmax", x, y)


def fargmin(x, y):
    """
    Build a pure-intrinsic ``fargmin`` call over two input expressions.

    The intrinsic yields the index of the smaller of ``x`` and ``y``. The call
    expression carries the dtype of ``x``.

    Args:
        x (tvm.expr.Expr): First input expression.
        y (tvm.expr.Expr): Second input expression.

    Returns:
        tvm.expr.Expr. The call expression.
    """
    result_dtype = x.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "fargmin", x, y)


def mad(x, y):
    """
    Build a pure-intrinsic ``mad`` (matrix multiply-and-add) call.

    Intended as the combiner of a matmul ``comm_reducer`` whose identity is 0
    (see example). The call expression carries the dtype of ``x``.

    Args:
        x (tvm.expr.Expr): First input expression.
        y (tvm.expr.Expr): Second input expression.

    Returns:
        tvm.expr.Expr. The call expression.

    Examples:
        >>> n = akg.tvm.var('n')
        >>> m = akg.tvm.var('m')
        >>> k = akg.tvm.var('k')
        >>> A = akg.tvm.placeholder((m, k), name='A')
        >>> B = akg.tvm.placeholder((k, n), name='B')
        >>> kk = akg.tvm.reduce_axis((0, k), name='kk')
        >>> mmad = akg.tvm.comm_reducer(lambda x, y: akg.mad(x, y), lambda t: akg.tvm.const(0, dtype=t), name="mmad")
        >>> C = akg.tvm.compute((m, n), lambda i, j: mmad(A[i, kk] * B[kk, j], axis=kk), name="C")
    """
    result_dtype = x.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "mad", x, y)


# Module-level reducer that folds values with the ``mad`` intrinsic; its identity
# element is the constant 0 of the reduced dtype. Same construction as the example
# in the ``mad`` docstring.
# NOTE(review): keep the combiner lambda's parameter names (x, y) stable;
# comm_reducer may derive generated variable names from them -- confirm before renaming.
mmad = akg.tvm.comm_reducer(lambda x, y: mad(x, y), lambda t: akg.tvm.const(0, dtype=t), name="mmad")


def dropout(x, y):
    """
    Build a pure-intrinsic ``dropout`` call on two input expressions.

    Unlike most helpers in this module, the call expression takes the dtype of
    the second operand ``y``, not ``x``.

    Args:
        x (tvm.expr.Expr): First input expression.
        y (tvm.expr.Expr): Second input expression; its dtype is the result dtype.

    Returns:
        tvm.expr.Expr. The call expression.
    """
    result_dtype = y.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "dropout", x, y)


def iou(x, y):
    """
    Build a pure-intrinsic ``iou`` call for boxes ``x`` and ``y``.

    The intrinsic computes the intersection over union of the two boxes. The
    call expression carries the dtype of ``x``.

    Args:
        x (tvm.expr.Expr): First box expression.
        y (tvm.expr.Expr): Second box expression.

    Returns:
        tvm.expr.Expr. The call expression.
    """
    result_dtype = x.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "iou", x, y)


def nms(x, y, scalar):
    """
    Build a pure-intrinsic ``nms`` (non-maximum suppression) call for boxes x, y.

    Args:
        x (tvm.expr.Expr): Input argument of the reduced tensor; its dtype is the
            result dtype.
        y (tvm.expr.Expr): Input argument.
        scalar (Union[tvm.expr.Expr, float]): Score threshold of nms.

    Returns:
        z : tvm.expr.Expr. The result is stored in fp16; each fp16 value is a hex
        number indicating suppression.
    """
    result_dtype = x.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "nms", x, y, scalar)


def topk_sort(dst, src, topk):
    """
    Build a pure-intrinsic ``topk_sort`` call that sorts proposal boxes and keeps
    the top-k; used when the sort must be split across a partitioned loop.

    Args:
        dst (tvm.expr.Expr): Destination of the sort, produced by a common reducer.
        src (tvm.expr.Expr): Source boxes; its dtype is the result dtype.
            The box count must be divisible by 16 and each box must have 8 items.
        topk (tvm.expr.Expr): Constant expression giving the required top-k count.

    Returns:
        z : tvm.expr.Expr. The call expression.
    """
    result_dtype = src.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "topk_sort", dst, src, topk)


def proposal_sort(dst, src, topk):
    """
    Build a pure-intrinsic ``proposal_sort`` call that sorts proposal boxes and
    keeps the top-k.

    Args:
        dst (tvm.expr.Expr): Destination of the sort, produced by a common reducer.
        src (tvm.expr.Expr): Source boxes; its dtype is the result dtype.
            The box count must be divisible by 16 and each box must have 8 items.
        topk (tvm.expr.Expr): Constant expression giving the required top-k count.

    Returns:
        z : tvm.expr.Expr. The call expression.
    """
    result_dtype = src.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "proposal_sort", dst, src, topk)


def fnot(x):
    """
    Build a pure-intrinsic ``not`` call on ``x``.

    Args:
        x (tvm.expr.Expr): Input expression; its dtype is the result dtype.

    Returns:
        tvm.expr.Expr. The call expression.
    """
    result_dtype = x.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "not", x)


def f_all(x, y):
    """
    Build a pure-intrinsic ``vand`` call on ``x`` and ``y``.

    Used as the combiner of the module-level ``all_op`` reducer.

    Args:
        x (tvm.expr.Expr): First input expression; its dtype is the result dtype.
        y (tvm.expr.Expr): Second input expression.

    Returns:
        tvm.expr.Expr. The call expression.
    """
    result_dtype = x.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "vand", x, y)


# Module-level reducer that folds values with ``f_all`` (the ``vand`` intrinsic);
# its identity element is the constant 1 of the reduced dtype, so it presumably
# behaves as a logical "all" reduction -- confirm against the backend's vand semantics.
# NOTE(review): keep the combiner lambda's parameter names (x, y) stable;
# comm_reducer may derive generated variable names from them -- confirm before renaming.
all_op = akg.tvm.comm_reducer(lambda x, y: f_all(x, y), lambda t: akg.tvm.const(1, dtype=t), name='all_op')


def round_to(data, max_, min_):
    """
    Clamp every element of ``data`` into the range [min_, max_].

    Despite the name, no rounding happens: values below ``min_`` become ``min_``
    and values above ``max_`` become ``max_``. The bounds are materialised as
    tensors shaped like ``data`` (``data * 0 + bound``) so that the elementwise
    ``vmax``/``vmin`` helpers can be applied.

    Args:
        data (Tensor): Input tensor to clamp.
        max_ (float): Upper bound of the result.
        min_ (float): Lower bound of the result.

    Returns:
        tensor : akg.tvm.tensor whose elements lie in [min_, max_].
    """
    zeros = vmuls(data, 0)
    lower_bound = vadds(zeros, min_)
    upper_bound = vadds(zeros, max_)
    return vmin(vmax(data, lower_bound), upper_bound)


def _fp16_to_int32_trunc(data):
    """
    Convert a float16 tensor to integer values by truncation toward zero.

    The magnitude is ``floor(|data|)``. The sign is recovered arithmetically:
    ``data`` is clamped to [-0.5, 0.5], scaled by 32768, then multiplied by
    ``vrec`` of its absolute value plus 2**-15 (i.e. divided by it, assuming
    ``vrec`` is an elementwise reciprocal). This gives roughly -1, 0 or +1,
    which ``round`` snaps to an exact sign.

    Args:
        data (Tensor): float16 input tensor.

    Returns:
        tensor : akg.tvm.tensor. ``floor(|data|) * sign(data)``; its dtype is
        whatever ``floor``/``round`` from cast_compute produce (presumably int32).
    """
    fp16_max = akg.tvm.const(32768, dtype="float16")
    # Small offset keeps the division finite when data == 0 (result is then 0).
    fp16_min = akg.tvm.const(2 ** (-15), dtype="float16")

    clamped = round_to(data, 0.5, -0.5)
    scaled = vmuls(clamped, fp16_max)
    denominator = vadds(vabs(scaled), fp16_min)
    # ``round`` here is the cast_compute import, not the Python builtin.
    sign = round(vmul(scaled, vrec(denominator)))

    magnitude = floor(vabs(data))
    return vmul(magnitude, sign)


def cast_to(data, dtype, f1628_int_flag=False):
    """
    A wrapped cast operation: cast ``data`` to ``dtype``.

    Conversions other than the special-cased ones are routed through float16.
    float16 -> int32 truncates toward zero.

    Args:
        data (Tensor): akg.tvm.tensor whose dtype needs to change.
        dtype (String): Destination dtype.
        f1628_int_flag (bool): For float16 -> int8/uint8 casts, whether the data is
            already known to consist only of integers. Default is False.

    Returns:
        tensor : akg.tvm.tensor.

    Raises:
        RuntimeError: If ``data`` is not an ``akg.tvm.tensor.Tensor``.
    """
    if not isinstance(data, akg.tvm.tensor.Tensor):
        raise RuntimeError("The cast input type must be akg.tvm.tensor")
    data_dtype = data.dtype

    # Nothing to do when the tensor already has the requested dtype. Neither
    # special case below can apply in that situation, so checking first is safe.
    if data_dtype == dtype:
        return data

    if data_dtype == "float16" and dtype == "int32":
        return _fp16_to_int32_trunc(data)

    if data_dtype == "float16" and dtype in ("int8", "uint8") and not f1628_int_flag:
        # Shift by -0.5 before the final cast when the input may be fractional;
        # presumably this makes the rounding cast behave like floor -- confirm
        # against cast_compute.
        fp16_half = akg.tvm.const(-0.5, dtype="float16")
        set_is_need_save_dtype()
        data = vadds(data, fp16_half)

    # Route through float16. NOTE(review): when dtype is "float16" this issues a
    # float16 -> float16 cast; presumably `cast` treats that as a no-op.
    tmp = data if data_dtype == "float16" else cast(data, dst_dtype="float16")
    return cast(tmp, dst_dtype=dtype)


def four2five_nchw(data):
    """
    Build a pure-intrinsic ``four2five_nchw`` call on ``data``.

    The name suggests a 4D (NCHW) to 5D layout conversion -- TODO confirm
    against the backend intrinsic.

    Args:
        data (tvm.expr.Expr): Input expression; its dtype is the result dtype.

    Returns:
        tvm.expr.Expr. The call expression.
    """
    result_dtype = data.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "four2five_nchw", data)

def load3d_l1_ub(data, pad_h, pad_t, pad_l, pad_r,
                 fm_h, fm_w, stride_h, stride_w,
                 filter_h, filter_w, dilation_h, dilation_w, repeat_mode, jmp_offset):
    """
    Build a pure-intrinsic ``load3d_l1_ub`` call.

    Every argument after ``data`` is forwarded unchanged, in signature order, as
    an operand of the intrinsic. Operand meanings (padding, feature-map size,
    stride, filter size, dilation, repeat mode, jump offset) are inferred from the
    parameter names -- confirm against the backend intrinsic.

    Args:
        data (tvm.expr.Expr): Source expression; its dtype is the result dtype.
        pad_h, pad_t, pad_l, pad_r: Padding operands.
        fm_h, fm_w: Feature-map size operands.
        stride_h, stride_w: Stride operands.
        filter_h, filter_w: Filter size operands.
        dilation_h, dilation_w: Dilation operands.
        repeat_mode: Repeat-mode operand.
        jmp_offset: Jump-offset operand.

    Returns:
        tvm.expr.Expr. The call expression.
    """
    padding = (pad_h, pad_t, pad_l, pad_r)
    feature_map = (fm_h, fm_w)
    strides = (stride_h, stride_w)
    kernel = (filter_h, filter_w)
    dilations = (dilation_h, dilation_w)
    return akg.tvm.call_pure_intrin(data.dtype, "load3d_l1_ub", data,
                                    *padding, *feature_map, *strides, *kernel, *dilations,
                                    repeat_mode, jmp_offset)

def sin(data):
    """
    Build a pure-intrinsic ``sin`` call on ``data``.

    Args:
        data (tvm.expr.Expr): Input expression; its dtype is the result dtype.

    Returns:
        tvm.expr.Expr. The call expression.
    """
    result_dtype = data.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "sin", data)

def cos(data):
    """
    Build a pure-intrinsic ``cos`` call on ``data``.

    Args:
        data (tvm.expr.Expr): Input expression; its dtype is the result dtype.

    Returns:
        tvm.expr.Expr. The call expression.
    """
    result_dtype = data.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "cos", data)

def sinh(data):
    """
    Build a pure-intrinsic ``sinh`` call on ``data``.

    Args:
        data (tvm.expr.Expr): Input expression; its dtype is the result dtype.

    Returns:
        tvm.expr.Expr. The call expression.
    """
    result_dtype = data.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "sinh", data)

def cosh(data):
    """
    Build a pure-intrinsic ``cosh`` call on ``data``.

    Args:
        data (tvm.expr.Expr): Input expression; its dtype is the result dtype.

    Returns:
        tvm.expr.Expr. The call expression.
    """
    result_dtype = data.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "cosh", data)

def divide_var(data, divisor):
    """
    Build a pure-intrinsic ``divide_var`` call dividing ``data`` by ``divisor``.

    Args:
        data (tvm.expr.Expr): Dividend expression; its dtype is the result dtype.
        divisor (tvm.expr.Expr): Divisor operand.

    Returns:
        tvm.expr.Expr. The call expression.
    """
    result_dtype = data.dtype
    return akg.tvm.call_pure_intrin(result_dtype, "divide_var", data, divisor)

def vmadd(x, y, z):
    """
    Call the vmadd instruction to calculate :math:`x * y + z`.

    The operands reach the intrinsic in the order (y, z, x); ``x`` goes last,
    presumably because it occupies the instruction's destination operand --
    confirm against the backend.

    Args:
        x (tvm.tensor.Tensor): input x; its dtype is the result dtype.
        y (tvm.tensor.Tensor): input y.
        z (tvm.tensor.Tensor): input z.

    Returns:
        tvm.expr.Expr. The call expression.
    """
    operands = (y, z, x)
    return akg.tvm.call_pure_intrin(x.dtype, "vmadd", *operands)

def vmla(x, y, z):
    """
    Call the vmla instruction to calculate :math:`x + y * z`.

    The operands reach the intrinsic in the order (y, z, x); the accumulator
    ``x`` goes last, presumably because it occupies the instruction's
    destination operand -- confirm against the backend.

    Args:
        x (tvm.tensor.Tensor): input x (the addend); its dtype is the result dtype.
        y (tvm.tensor.Tensor): input y.
        z (tvm.tensor.Tensor): input z.

    Returns:
        tvm.expr.Expr. The call expression.
    """
    operands = (y, z, x)
    return akg.tvm.call_pure_intrin(x.dtype, "vmla", *operands)
